{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "UViM color task",
      "provenance": [],
      "collapsed_sections": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Fetch the big_vision repository and copy its package into the current\n",
        "# workdir so that `import big_vision` resolves. The `test -d` guards make the\n",
        "# cell safe to re-run: a second `git clone` into an existing directory fails,\n",
        "# and a second `cp -R` would nest the package as big_vision/big_vision.\n",
        "!test -d big_vision_repo || git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n",
        "!test -d big_vision || cp -R big_vision_repo/big_vision big_vision\n",
        "!pip install -qr big_vision/requirements.txt"
      ],
      "metadata": {
        "id": "sKZK6_QpVI_O"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "from big_vision.models.proj.uvim import vtt  # stage-II model\n",
        "from big_vision.models.proj.uvim import vit  # stage-I model\n",
        "\n",
        "from big_vision.models.proj.uvim import decode\n",
        "from big_vision.trainers.proj.uvim import colorization_task as task\n",
        "from big_vision.configs.proj.uvim import train_imagenet2012_colorization_pretrained as config_module\n",
        "\n",
        "import big_vision.pp.ops_image\n",
        "import big_vision.pp.ops_general\n",
        "import big_vision.pp.proj.uvim.pp_ops\n",
        "from big_vision.pp import builder as pp_builder\n",
        "\n",
        "config = config_module.get_config()\n",
        "res = 512  # Side length passed to `resize` in the preprocessing below.\n",
        "seq_len = config.model.seq_len\n",
        "\n",
        "# `config.model` parameterizes the stage-II model (vtt) and\n",
        "# `config.oracle.model` the stage-I model (vit).\n",
        "lm_model = vtt.Model(**config.model)\n",
        "oracle_model = vit.Model(**config.oracle.model)\n",
        "\n",
        "# Build the pipeline string from `res` instead of repeating the literal 512, so\n",
        "# the constant above is actually the value used for resizing.\n",
        "# NOTE(review): the pretrained checkpoints may only support 512 — confirm\n",
        "# before changing `res`.\n",
        "preprocess_fn = pp_builder.get_preprocess_fn(\n",
        "    f'decode|resize({res})|'\n",
        "    'rgb_to_grayscale_to_rgb|value_range(-1,1)|'\n",
        "    'copy(inkey=\"image\",outkey=\"image_ctx\")')\n",
        "\n",
        "@jax.jit\n",
        "def predict_code(params, x, rng, temperature):\n",
        "  \"\"\"Samples one code sequence per input image from the stage-II model.\n",
        "\n",
        "  Args:\n",
        "    params: variables passed to `decode.temperature_sampling` for `lm_model`.\n",
        "    x: batch dict; only `x[\"image\"]` is read.\n",
        "    rng: PRNG key used as the sampling seed.\n",
        "    temperature: sampling temperature.\n",
        "\n",
        "  Returns:\n",
        "    The sampled sequences with the `num_samples` axis removed and every token\n",
        "    decremented by one.\n",
        "  \"\"\"\n",
        "  images = x[\"image\"]\n",
        "  batch_size = images.shape[0]\n",
        "  sampled, _, _ = decode.temperature_sampling(\n",
        "      params=params,\n",
        "      model=lm_model,\n",
        "      seed=rng,\n",
        "      inputs=images,\n",
        "      prompts=jnp.zeros((batch_size, seq_len), dtype=jnp.int32),\n",
        "      temperature=temperature,\n",
        "      num_samples=1,\n",
        "      eos_token=-1,\n",
        "      prefill=False)\n",
        "  # Drop the num_samples axis, then shift tokens down by one.\n",
        "  # NOTE(review): the -1 presumably undoes a +1 vocabulary offset used by the\n",
        "  # language model — confirm against the trainer.\n",
        "  return jnp.squeeze(sampled, axis=1) - 1\n",
        "  \n",
        "@jax.jit\n",
        "def labels2code(params, x, ctx):\n",
        "  \"\"\"Runs `oracle_model.encode` on `x` with `ctx` and returns `aux[\"code\"]`.\"\"\"\n",
        "  _, encoder_aux = oracle_model.apply(\n",
        "      params, x, ctx=ctx, train=False, method=oracle_model.encode)\n",
        "  return encoder_aux[\"code\"]\n",
        "\n",
        "@jax.jit\n",
        "def code2labels(params, code, ctx):\n",
        "  \"\"\"Decodes a discrete `code` with `oracle_model.decode` into task outputs.\"\"\"\n",
        "  logits, _ = oracle_model.apply(\n",
        "      params,\n",
        "      code,\n",
        "      ctx=ctx,\n",
        "      train=False,\n",
        "      discrete_input=True,\n",
        "      method=oracle_model.decode)\n",
        "  return task.predict_outputs(logits, config.oracle)"
      ],
      "metadata": {
        "id": "QzThueWDzc7I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the stage-I and stage-II checkpoints (`-n` skips files already present).\n",
        "!gsutil cp -n gs://big_vision/uvim/color_stageI_params.npz gs://big_vision/uvim/color_stageII_params.npz .\n",
        "\n",
        "# Stage-I (oracle) checkpoint: parameters plus model state.\n",
        "oracle_weights, oracle_state = vit.load(None, \"color_stageI_params.npz\")\n",
        "oracle_params = jax.device_put(\n",
        "    {\"params\": oracle_weights, \"state\": oracle_state})\n",
        "\n",
        "# Stage-II (language model) checkpoint: parameters only.\n",
        "lm_weights = vtt.load(None, \"color_stageII_params.npz\")\n",
        "lm_params = jax.device_put({\"params\": lm_weights})"
      ],
      "metadata": {
        "id": "AEjRgshLa6Fp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build a preprocessed dataset from the COCO val2017 images:\n",
        "#  - https://cocodataset.org/\n",
        "import os\n",
        "import tensorflow as tf\n",
        "\n",
        "# Download and unpack the archive only when the image folder is missing.\n",
        "if not os.path.exists(\"val2017/\"):\n",
        "  !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip\n",
        "  !unzip -uq val2017.zip\n",
        "\n",
        "# File names are listed in shuffled order, read as raw bytes, then preprocessed.\n",
        "image_files = tf.data.Dataset.list_files(\"val2017/*.jpg\", shuffle=True)\n",
        "dataset = (\n",
        "    image_files\n",
        "    .map(lambda path: {\"image\": tf.io.read_file(path)})\n",
        "    .map(preprocess_fn))"
      ],
      "metadata": {
        "id": "BKifDDRnH_Ll"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on a few examples:\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "num_examples = 4\n",
        "temperature = jnp.array(1.0)\n",
        "key = jax.random.PRNGKey(0)\n",
        "\n",
        "# One image per batch; iterate the first `num_examples` batches as numpy arrays.\n",
        "single_image_batches = dataset.batch(1)\n",
        "data = single_image_batches.take(num_examples).as_numpy_iterator()\n",
        "\n",
        "def render_example(image, prediction):\n",
        "  \"\"\"Shows `image` (left) and `prediction` (right) side by side, axes hidden.\n",
        "\n",
        "  Both arrays are rescaled with `v * 0.5 + 0.5` before display.\n",
        "  \"\"\"\n",
        "  _, axes = plt.subplots(1, 2, figsize=(10, 10))\n",
        "  for panel, pixels in zip(axes, (image, prediction)):\n",
        "    panel.imshow(pixels*0.5 + 0.5)\n",
        "    panel.axis(\"off\")\n",
        "\n",
        "for idx, batch in enumerate(data):\n",
        "  # Derive a distinct PRNG key per example. Passing the base `key` to every call\n",
        "  # would make the sampler reuse identical randomness for each image.\n",
        "  subkey = jax.random.fold_in(key, idx)\n",
        "  code = predict_code(lm_params, batch, subkey, temperature)\n",
        "  # Build the stage-I decoder context from the batch, then decode the sampled code.\n",
        "  aux_inputs = task.input_pp(batch, config.oracle)\n",
        "  prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n",
        "  render_example(batch[\"image\"][0], prediction[\"color\"][0])"
      ],
      "metadata": {
        "id": "TuevCy33nuv3"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}